Лабораторная работа 4: Семантическая сегментация с использованием PyTorch¶

Цели работы¶

Цель — разработать и обучить сверточную нейронную сеть для задачи мультиклассовой семантической сегментации изображений на наборе данных SUIM с использованием PyTorch.

Набор данных.¶

  1. Данные содержат 8 классов. Маска сегментации имеет вид трехканального изображения с пикселями, значения которых равно либо 0, либо 255, например, (0, 0, 0), (0, 0, 255) и так далее. Помимо этого встречаются и промежуточные значения, отличные от 0 и 255. В рамках данной лабораторной работы предлагается следующее преобразование: значения маски, меньшие 128, нужно установить в 0, а значения, равные или больше 128, установить в 255.

  2. Для упрощения работы рекомендуется объединить следующие классы в один:

  • класс 2 - Aquatic plants and sea-grass
  • класс 3 - Wrecks and ruins
  • класс 5 - Reefs and invertebrates
  • класс 7 - Sea-floor and rocks

Требования¶

  1. Необходимо выполнить и отобразить в Jupyter следующие задачи:

    • Загрузка и проверка данных. Загрузить и предобработать данные с демонстрацией избранных изображений и соответствующих масок, чтобы подтвердить корректность загрузки и соответствие размерностей данных.
    • Реализация архитектуры сети. Разработать архитектуру нейронной сети для сегментации с использованием фреймворка PyTorch.
    • Настройка гиперпараметров обучения. Настроить параметры обучения, такие как функция ошибки, размер сети, скорость обучения и другие параметры.
    • Тестирование модели. После завершения обучения для оценки качества работы необходимо посчитать accuracy, IoU и визуализировать confusion matrix (с нормализацией, normalize='true').
    • Визуализация результатов. После завершения обучения построить и отобразить результаты сегментации на тестовых изображениях, сравнивая с реальными масками сегментации.
  2. Выбор архитектуры:

  • Можно использовать или адаптировать известные архитектуры глубокого обучения.
  • Может быть полезным:
    • уменьшить количество параметров в нейронной сети и размер входного изображения для ускорения сходимости, предотвращения переобучения и ускорения работы нейронной сети.
    • использовать аугментацию данных и взвешенные/специализированные функции ошибки. При аугментации данных необходимо учитывать связь изображений с маской классов.
  • Использовать перенос знаний недопустимо.
In [ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import confusion_matrix
import seaborn as sea

import matplotlib.pyplot as plt
from torchsummary import summary
from PIL import Image
import numpy as np
import time

Загрузка и проверка корректности данных

In [ ]:
import os
import shutil

from google.colab import drive
drive.mount('/content/drive')

def copy_files_recursive(source_folder, destination_folder):
    for root, dirs, files in os.walk(source_folder):
        for file in files:
            source_path = os.path.join(root, file)
            destination_path = os.path.join(destination_folder, os.path.relpath(source_path, source_folder))

            os.makedirs(os.path.dirname(destination_path), exist_ok=True)

            shutil.copyfile(source_path, destination_path)
Mounted at /content/drive
In [ ]:
remote_root = '/content/drive/MyDrive/SUIM'
root = '/content/SUIM'
copy_files_recursive(remote_root, root)
In [ ]:
number_classes = 5

classes = {
    "background": [(0, 0, 0)],
    "human_divers": [(0, 0, 1)],
    "robots": [(1, 0, 0)],
    "fish_vertebrates": [(1, 1, 0)],
    "other": [(0, 1, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)]
}

color_classes = [[0, 0, 0],
                 [0, 0, 1],
                 [1, 0, 0],
                 [1, 1, 0],
                 [0, 1, 0]]
In [ ]:
class CustomDataset(Dataset):
    def __init__(self, images, masks):
        self.images = torch.tensor(images, dtype = torch.float)
        self.masks = torch.tensor(masks, dtype = torch.float)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        return image, mask


def load_dataset(root_images, root_masks, image_size):
    images = []
    list_dir = sorted(os.listdir(root_images))
    for file_name in list_dir:
        file_path = os.path.join(root_images, file_name)
        if os.path.isfile(file_path):
            with Image.open(file_path) as image:
                resized_image = np.array(image.resize(image_size)) / 255
                images.append(resized_image)

    labels = []
    list_dir = sorted(os.listdir(root_masks))
    for file_name in list_dir:
        file_path = os.path.join(root_masks, file_name)
        if os.path.isfile(file_path):
            with Image.open(file_path) as mask:
                background = np.zeros(image_size)
                human_divers = np.zeros(image_size)
                robots = np.zeros(image_size)
                fish_vertebrates = np.zeros(image_size)
                other = np.zeros(image_size)

                resized_mask = np.array(mask.resize(image_size)) / 255
                resized_mask = np.where(resized_mask < 0.5, 0, 1)

                for i in range(image_size[0]):
                    for j in range(image_size[1]):
                        if np.all(resized_mask[i, j] == classes["background"], axis = -1):
                            background[i, j] = 1
                        elif np.all(resized_mask[i, j] == classes["human_divers"], axis = -1):
                            human_divers[i, j] = 1
                        elif np.all(resized_mask[i, j] == classes["robots"], axis = -1):
                            robots[i, j] = 1
                        elif np.all(resized_mask[i, j] == classes["fish_vertebrates"], axis = -1):
                            fish_vertebrates[i, j] = 1
                        else:
                            other[i, j] = 1

                labels.append(np.stack([background, human_divers, robots, fish_vertebrates, other], -1))

    images = np.array(images)
    labels = np.array(labels)
    dataset = CustomDataset(images, labels)

    return dataset
In [ ]:
def dataset_info(dataset):
    print("Размер датасета изображений:", dataset.images.shape)
    print("Размер датасета масок:", dataset.masks.shape)
    print()

    number_pixels = {'Background': np.count_nonzero(dataset.masks[:, :, :, 0] == 1),
                     'Human divers': np.count_nonzero(dataset.masks[:, :, :, 1] == 1),
                     'Robots': np.count_nonzero(dataset.masks[:, :, :, 2] == 1),
                     'Fish and vertebrates': np.count_nonzero(dataset.masks[:, :, :, 3] == 1),
                     'Other': np.count_nonzero(dataset.masks[:, :, :, 4] == 1)}

    sum_pixel = dataset.images.shape[0] * dataset.images.shape[1] * dataset.images.shape[2]

    for key, value in number_pixels.items():
        print(f'Класс: {key}, Число пикселей: {value}({(value / sum_pixel * 100):.2f}%)')
In [ ]:
image_size = 80
In [ ]:
root_train_images = "/content/SUIM/train_val/images"
root_train_masks = "/content/SUIM/train_val/masks"
train_dataset = load_dataset(root_train_images, root_train_masks, (image_size, image_size))
In [ ]:
root_test_images = "/content/SUIM/TEST/images"
root_test_masks = "/content/SUIM/TEST/masks"
test_dataset = load_dataset(root_test_images, root_test_masks, (image_size, image_size))
In [ ]:
print('Train dataset:')
dataset_info(train_dataset)
print()
print('Test dataset:')
dataset_info(test_dataset)
Train dataset:
Размер датасета изображений: torch.Size([1525, 80, 80, 3])
Размер датасета масок: torch.Size([1525, 80, 80, 5])

Класс: Background, Число пикселей: 3034338(31.09%)
Класс: Human divers, Число пикселей: 184114(1.89%)
Класс: Robots, Число пикселей: 37740(0.39%)
Класс: Fish and vertebrates, Число пикселей: 767257(7.86%)
Класс: Other, Число пикселей: 5736551(58.78%)

Test dataset:
Размер датасета изображений: torch.Size([110, 80, 80, 3])
Размер датасета масок: torch.Size([110, 80, 80, 5])

Класс: Background, Число пикселей: 282598(40.14%)
Класс: Human divers, Число пикселей: 20661(2.93%)
Класс: Robots, Число пикселей: 4557(0.65%)
Класс: Fish and vertebrates, Число пикселей: 54083(7.68%)
Класс: Other, Число пикселей: 342101(48.59%)
In [ ]:
def plot_images_with_masks(dataset, title):
    vert_size = 6
    horiz_size = 3
    fig, axes = plt.subplots(vert_size, horiz_size * 2, figsize = (15, 15))
    fig.suptitle(title)

    mask_sizes = (image_size, image_size, 3)

    count_images = vert_size * horiz_size
    for number in range(count_images):
        i = number // horiz_size
        j = number % horiz_size

        image, mask = dataset[number]

        axes[i, j * 2].imshow(image, cmap=plt.cm.binary)
        axes[i, j * 2].axis('off')

        rgb_mask = np.zeros(mask_sizes)
        for k in range(number_classes):
            rgb_mask[mask[:, :, k] > 0] = color_classes[k]

        axes[i, j * 2 + 1].imshow(image, cmap=plt.cm.binary)
        axes[i, j * 2 + 1].imshow(rgb_mask, alpha = 0.35)
        axes[i, j * 2 + 1].axis('off')

    plt.tight_layout()
    plt.show()
In [ ]:
plot_images_with_masks(train_dataset, 'Examples from train dataset')
No description has been provided for this image
In [ ]:
plot_images_with_masks(test_dataset, 'Examples from test dataset')
No description has been provided for this image
In [ ]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')
Device: cuda

Эксперименты¶

Гиперпараметры обучения

In [ ]:
learning_rate = 0.01
epochs = 50
batch_size = 36

Разбиение датасетов на батчи

In [ ]:
training_dataset, validation_dataset = torch.utils.data.random_split(train_dataset, [0.85, 0.15])
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True)
validation_loader = DataLoader(validation_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)
In [ ]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 'same'),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.25)
        )

    def forward(self, x):
        return self.double_conv(x)


class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        x = self.max_pool(x)
        return self.conv(x)


class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim = 1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels = 3, n_classes = 5):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.input = DoubleConv(n_channels, 64)
        self.down1 = DownSample(64, 128)
        self.down2 = DownSample(128, 256)
        self.down3 = DownSample(256, 512)
        self.down4 = DownSample(512, 1024)

        self.up1 = UpSample(1024 + 512, 512)
        self.up2 = UpSample(512 + 256, 256)
        self.up3 = UpSample(256 + 128, 128)
        self.up4 = UpSample(128 + 64, 64)

        self.output = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.input(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        logits = self.output(x)
        return torch.sigmoid(logits)
In [ ]:
net = UNet().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr = learning_rate)
summary(net, (3, image_size, image_size))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 80, 80]           1,792
       BatchNorm2d-2           [-1, 64, 80, 80]             128
              ReLU-3           [-1, 64, 80, 80]               0
           Dropout-4           [-1, 64, 80, 80]               0
        DoubleConv-5           [-1, 64, 80, 80]               0
         MaxPool2d-6           [-1, 64, 40, 40]               0
            Conv2d-7          [-1, 128, 40, 40]          73,856
       BatchNorm2d-8          [-1, 128, 40, 40]             256
              ReLU-9          [-1, 128, 40, 40]               0
          Dropout-10          [-1, 128, 40, 40]               0
       DoubleConv-11          [-1, 128, 40, 40]               0
       DownSample-12          [-1, 128, 40, 40]               0
        MaxPool2d-13          [-1, 128, 20, 20]               0
           Conv2d-14          [-1, 256, 20, 20]         295,168
      BatchNorm2d-15          [-1, 256, 20, 20]             512
             ReLU-16          [-1, 256, 20, 20]               0
          Dropout-17          [-1, 256, 20, 20]               0
       DoubleConv-18          [-1, 256, 20, 20]               0
       DownSample-19          [-1, 256, 20, 20]               0
        MaxPool2d-20          [-1, 256, 10, 10]               0
           Conv2d-21          [-1, 512, 10, 10]       1,180,160
      BatchNorm2d-22          [-1, 512, 10, 10]           1,024
             ReLU-23          [-1, 512, 10, 10]               0
          Dropout-24          [-1, 512, 10, 10]               0
       DoubleConv-25          [-1, 512, 10, 10]               0
       DownSample-26          [-1, 512, 10, 10]               0
        MaxPool2d-27            [-1, 512, 5, 5]               0
           Conv2d-28           [-1, 1024, 5, 5]       4,719,616
      BatchNorm2d-29           [-1, 1024, 5, 5]           2,048
             ReLU-30           [-1, 1024, 5, 5]               0
          Dropout-31           [-1, 1024, 5, 5]               0
       DoubleConv-32           [-1, 1024, 5, 5]               0
       DownSample-33           [-1, 1024, 5, 5]               0
         Upsample-34         [-1, 1024, 10, 10]               0
           Conv2d-35          [-1, 512, 10, 10]       7,078,400
      BatchNorm2d-36          [-1, 512, 10, 10]           1,024
             ReLU-37          [-1, 512, 10, 10]               0
          Dropout-38          [-1, 512, 10, 10]               0
       DoubleConv-39          [-1, 512, 10, 10]               0
         UpSample-40          [-1, 512, 10, 10]               0
         Upsample-41          [-1, 512, 20, 20]               0
           Conv2d-42          [-1, 256, 20, 20]       1,769,728
      BatchNorm2d-43          [-1, 256, 20, 20]             512
             ReLU-44          [-1, 256, 20, 20]               0
          Dropout-45          [-1, 256, 20, 20]               0
       DoubleConv-46          [-1, 256, 20, 20]               0
         UpSample-47          [-1, 256, 20, 20]               0
         Upsample-48          [-1, 256, 40, 40]               0
           Conv2d-49          [-1, 128, 40, 40]         442,496
      BatchNorm2d-50          [-1, 128, 40, 40]             256
             ReLU-51          [-1, 128, 40, 40]               0
          Dropout-52          [-1, 128, 40, 40]               0
       DoubleConv-53          [-1, 128, 40, 40]               0
         UpSample-54          [-1, 128, 40, 40]               0
         Upsample-55          [-1, 128, 80, 80]               0
           Conv2d-56           [-1, 64, 80, 80]         110,656
      BatchNorm2d-57           [-1, 64, 80, 80]             128
             ReLU-58           [-1, 64, 80, 80]               0
          Dropout-59           [-1, 64, 80, 80]               0
       DoubleConv-60           [-1, 64, 80, 80]               0
         UpSample-61           [-1, 64, 80, 80]               0
           Conv2d-62            [-1, 5, 80, 80]             325
          OutConv-63            [-1, 5, 80, 80]               0
================================================================
Total params: 15,678,085
Trainable params: 15,678,085
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.07
Forward/backward pass size (MB): 82.03
Params size (MB): 59.81
Estimated Total Size (MB): 141.91
----------------------------------------------------------------
In [ ]:
def train(net, train_loader, validation_loader, criterion, epochs):
  for epoch in range(epochs):
      loss_list = []
      time_one = time.time()
      for data in train_loader:
          images = data[0].permute(0, 3, 1, 2).to(device)
          labels = data[1].permute(0, 3, 1, 2).to(device)

          outputs = net(images)
          loss = criterion(outputs, labels)

          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          loss_list.append(loss)
      diff_time = time.time() - time_one

      loss_validation_list = []
      time_one = time.time()
      for data in validation_loader:
          images = data[0].permute(0, 3, 1, 2).to(device)
          labels = data[1].permute(0, 3, 1, 2).to(device)

          outputs = net(images)
          loss = criterion(outputs, labels)

          loss_validation_list.append(loss)
      diff_time_validation = time.time() - time_one

      print(f"Epoch: {epoch + 1}/{epochs}, Train "
            f"Loss: {torch.stack(loss_list).mean():.4f}, "
            f"Time: {diff_time:.2f} Validation "
            f"Loss: {torch.stack(loss_validation_list).mean():.4f}, "
            f"Time: {diff_time_validation:.2f}")
In [ ]:
train(net, training_loader, validation_loader, criterion, epochs)
Epoch: 1/50, Train Loss: 0.3104, Time: 6.52 Validation Loss: 0.2474, Time: 0.44
Epoch: 2/50, Train Loss: 0.2360, Time: 6.46 Validation Loss: 0.2362, Time: 0.44
Epoch: 3/50, Train Loss: 0.2245, Time: 6.49 Validation Loss: 0.2295, Time: 0.45
Epoch: 4/50, Train Loss: 0.2223, Time: 6.53 Validation Loss: 0.2209, Time: 0.45
Epoch: 5/50, Train Loss: 0.2197, Time: 6.56 Validation Loss: 0.2144, Time: 0.45
Epoch: 6/50, Train Loss: 0.2108, Time: 6.57 Validation Loss: 0.2093, Time: 0.45
Epoch: 7/50, Train Loss: 0.2107, Time: 6.57 Validation Loss: 0.2316, Time: 0.45
Epoch: 8/50, Train Loss: 0.2240, Time: 6.60 Validation Loss: 0.2081, Time: 0.45
Epoch: 9/50, Train Loss: 0.2034, Time: 6.63 Validation Loss: 0.2433, Time: 0.46
Epoch: 10/50, Train Loss: 0.2316, Time: 6.65 Validation Loss: 0.2114, Time: 0.46
Epoch: 11/50, Train Loss: 0.2001, Time: 6.66 Validation Loss: 0.2070, Time: 0.46
Epoch: 12/50, Train Loss: 0.1955, Time: 6.69 Validation Loss: 0.1991, Time: 0.46
Epoch: 13/50, Train Loss: 0.1875, Time: 6.70 Validation Loss: 0.1935, Time: 0.46
Epoch: 14/50, Train Loss: 0.2033, Time: 6.73 Validation Loss: 0.1912, Time: 0.47
Epoch: 15/50, Train Loss: 0.1925, Time: 6.74 Validation Loss: 0.1894, Time: 0.47
Epoch: 16/50, Train Loss: 0.1870, Time: 6.75 Validation Loss: 0.1891, Time: 0.46
Epoch: 17/50, Train Loss: 0.1854, Time: 6.76 Validation Loss: 0.1782, Time: 0.47
Epoch: 18/50, Train Loss: 0.1903, Time: 6.77 Validation Loss: 0.1854, Time: 0.47
Epoch: 19/50, Train Loss: 0.1804, Time: 6.78 Validation Loss: 0.1840, Time: 0.47
Epoch: 20/50, Train Loss: 0.1779, Time: 6.79 Validation Loss: 0.1780, Time: 0.47
Epoch: 21/50, Train Loss: 0.1723, Time: 6.81 Validation Loss: 0.1886, Time: 0.47
Epoch: 22/50, Train Loss: 0.1943, Time: 6.81 Validation Loss: 0.1729, Time: 0.47
Epoch: 23/50, Train Loss: 0.1789, Time: 6.82 Validation Loss: 0.1786, Time: 0.47
Epoch: 24/50, Train Loss: 0.1766, Time: 6.83 Validation Loss: 0.1799, Time: 0.47
Epoch: 25/50, Train Loss: 0.1621, Time: 6.84 Validation Loss: 0.1611, Time: 0.47
Epoch: 26/50, Train Loss: 0.1564, Time: 6.84 Validation Loss: 0.1699, Time: 0.47
Epoch: 27/50, Train Loss: 0.1520, Time: 6.85 Validation Loss: 0.1851, Time: 0.47
Epoch: 28/50, Train Loss: 0.1576, Time: 6.85 Validation Loss: 0.1693, Time: 0.47
Epoch: 29/50, Train Loss: 0.1522, Time: 6.87 Validation Loss: 0.1595, Time: 0.48
Epoch: 30/50, Train Loss: 0.1737, Time: 6.87 Validation Loss: 0.1658, Time: 0.48
Epoch: 31/50, Train Loss: 0.1665, Time: 6.87 Validation Loss: 0.1772, Time: 0.48
Epoch: 32/50, Train Loss: 0.1901, Time: 6.87 Validation Loss: 0.1819, Time: 0.48
Epoch: 33/50, Train Loss: 0.1714, Time: 6.88 Validation Loss: 0.1764, Time: 0.48
Epoch: 34/50, Train Loss: 0.1575, Time: 6.89 Validation Loss: 0.1655, Time: 0.47
Epoch: 35/50, Train Loss: 0.1525, Time: 6.89 Validation Loss: 0.1816, Time: 0.48
Epoch: 36/50, Train Loss: 0.1832, Time: 6.89 Validation Loss: 0.1725, Time: 0.48
Epoch: 37/50, Train Loss: 0.1644, Time: 6.89 Validation Loss: 0.1715, Time: 0.48
Epoch: 38/50, Train Loss: 0.1535, Time: 6.90 Validation Loss: 0.1592, Time: 0.48
Epoch: 39/50, Train Loss: 0.1496, Time: 6.89 Validation Loss: 0.1693, Time: 0.48
Epoch: 40/50, Train Loss: 0.1474, Time: 6.90 Validation Loss: 0.1576, Time: 0.48
Epoch: 41/50, Train Loss: 0.1414, Time: 6.90 Validation Loss: 0.1567, Time: 0.48
Epoch: 42/50, Train Loss: 0.1452, Time: 6.89 Validation Loss: 0.1592, Time: 0.48
Epoch: 43/50, Train Loss: 0.1335, Time: 6.89 Validation Loss: 0.1553, Time: 0.48
Epoch: 44/50, Train Loss: 0.1347, Time: 6.89 Validation Loss: 0.1624, Time: 0.48
Epoch: 45/50, Train Loss: 0.1383, Time: 6.90 Validation Loss: 0.1554, Time: 0.48
Epoch: 46/50, Train Loss: 0.1467, Time: 6.89 Validation Loss: 0.1548, Time: 0.48
Epoch: 47/50, Train Loss: 0.1238, Time: 6.91 Validation Loss: 0.1572, Time: 0.48
Epoch: 48/50, Train Loss: 0.1177, Time: 6.90 Validation Loss: 0.1543, Time: 0.48
Epoch: 49/50, Train Loss: 0.1138, Time: 6.90 Validation Loss: 0.1583, Time: 0.48
Epoch: 50/50, Train Loss: 0.1077, Time: 6.88 Validation Loss: 0.1564, Time: 0.48
In [ ]:
def IoU(labels, predict):
    intersection = np.logical_and(labels, predict)
    union = np.logical_or(labels, predict)
    if np.sum(union) == 0:
        iou_score = 0
    else:
        iou_score = np.sum(intersection) / np.sum(union)
    return iou_score

def metrics_compute(net, data_loader):
    accuracy_list, IoU_list = [], []
    cm_mask, cm_predict_mask = [], []

    with torch.no_grad():
        for images, masks in data_loader:
            images = images.permute(0, 3, 1, 2).to(device)
            masks = masks.permute(0, 3, 1, 2).numpy()

            predict_masks = net(images)
            predict_masks = torch.where(predict_masks < torch.tensor(0.5), torch.tensor(0), torch.tensor(1)).cpu().numpy()

            for k in range(5):
                masks_tmp = masks[:, k, :, :]
                predict_masks_tmp = predict_masks[:, k, :, :]

                masks_tmp = np.where(masks_tmp != 1, 0, k + 1)
                predict_masks_tmp = np.where(predict_masks_tmp != 1, 0, k + 1)
                masks[:, k, :, :] = masks_tmp
                predict_masks[:, k, :, :] = predict_masks_tmp

            cm_mask.append(masks)
            cm_predict_mask.append(predict_masks)

            temp_accucary, temp_iou = [], []
            for k in range(5):
                accuracy = np.mean(predict_masks[:, k, :, :] == masks[:, k, :, :])
                iou = IoU(masks[:, k, :, :], predict_masks[:, k, :, :])
                temp_accucary.append(accuracy)
                temp_iou.append(iou)

            accuracy_list.append(temp_accucary)
            IoU_list.append(temp_iou)

    accuracy_list = np.array(accuracy_list)
    IoU_list = np.array(IoU_list)

    print(f'Оценка Accuracy для каждого класса: {np.mean(accuracy_list, axis = 0, dtype = np.float16)}')
    print(f'Оценка IoU для каждого класса: {np.mean(IoU_list, axis = 0, dtype = np.float16)}')

    print(f'Оценка Accuracy на данных: {np.mean(accuracy_list, dtype = np.float16):.4f}')
    print(f'Оценка IoU на данных: {np.mean(IoU_list, dtype = np.float16):.4f}')

    cm_mask = np.concatenate(cm_mask, axis = 0)
    cm_predict_mask = np.concatenate(cm_predict_mask, axis = 0)
    cm_mask = cm_mask.flatten()
    cm_predict_mask = cm_predict_mask.flatten()

    name_class = ['Background', 'Human divers', 'Robots', 'Fish and vertebrates', 'Other']
    cm = confusion_matrix(cm_mask, cm_predict_mask, labels = np.arange(5), normalize = 'true')
    sea.heatmap(cm, annot = True, cmap = 'Blues', xticklabels = name_class, yticklabels = name_class)
    plt.xlabel('Предсказанные классы')
    plt.ylabel('Истинные классы')
    plt.title('Confusion matrix')
    plt.show()
In [91]:
metrics_compute(net, test_loader)
Оценка Accuracy для каждого класса: [0.911  0.979  0.995  0.9478 0.875 ]
Оценка IoU для каждого класса: [0.792  0.1936 0.0718 0.2764 0.786 ]
Оценка Accuracy на данных: 0.9414
Оценка IoU на данных: 0.4243
No description has been provided for this image
In [96]:
def plot_images_with_masks_test(net, dataset):
    vert_size = 12
    horiz_size = 2
    fig, axes = plt.subplots(vert_size, horiz_size * 3, figsize = (15, 25))
    fig.suptitle("Predicted vs. True")

    mask_sizes = (image_size, image_size, 3)

    count_images = vert_size * horiz_size
    for number in range(count_images):
        i = number // horiz_size
        j = number % horiz_size

        image, mask = dataset[number]

        axes[i, j * 3].imshow(image, cmap=plt.cm.binary)
        axes[i, j * 3].set_title('Image', fontsize = 10)
        axes[i, j * 3].axis('off')

        with torch.no_grad():
            images = image
            images = images.unsqueeze(0)
            images = images.permute(0, 3, 1, 2).to(device)
            predict_mask = net(images)
            predict_mask = torch.where(predict_mask < torch.tensor(0.5), torch.tensor(0), torch.tensor(1)).permute(0, 2, 3, 1).cpu()

        rgb_predicted_mask = np.zeros(mask_sizes)
        for k in range(number_classes):
            rgb_predicted_mask[predict_mask[0, :, :, k] > 0] = color_classes[k]

        axes[i, j * 3 + 1].imshow(image, cmap=plt.cm.binary)
        axes[i, j * 3 + 1].imshow(rgb_predicted_mask, alpha = 0.35)
        axes[i, j * 3 + 1].set_title('Predicted', fontsize = 10)
        axes[i, j * 3 + 1].axis('off')

        rgb_mask = np.zeros(mask_sizes)
        for k in range(number_classes):
            rgb_mask[mask[:, :, k] > 0] = color_classes[k]

        axes[i, j * 3 + 2].imshow(image, cmap=plt.cm.binary)
        axes[i, j * 3 + 2].imshow(rgb_mask, alpha = 0.35)
        axes[i, j * 3 + 2].set_title('True', fontsize = 10)
        axes[i, j * 3 + 2].axis('off')

    plt.tight_layout()
    plt.show()
In [97]:
plot_images_with_masks_test(net, test_dataset)
No description has been provided for this image

Аугментация данных и перебалансировка классов

В тренировочных данных у нас большой перекос данных в сторону классов "Background" и "Other". Попробуем использовать аугментацию для увеличения количества пикселей маленьких классов и уберем картинки, на которых большая часть это "Background" и "Other"

In [82]:
def data_rebalancing(dataset, coeff):
    new_images = []
    new_masks = []

    mask_sizes = (image_size, image_size, 3)

    count_pixels = image_size * image_size

    for image, mask in dataset:
        rgb_mask = np.zeros(mask_sizes)
        for k in range(number_classes):
            rgb_mask[mask[:, :, k] > 0] = color_classes[k]

        count_colors = [0, 0, 0, 0, 0]
        for i in range(image_size):
            for j in range(image_size):
                for k in range(number_classes):
                    if np.all(rgb_mask[i, j] == color_classes[k], axis=-1):
                        count_colors[k] += 1

        back_other_colors = count_colors[0] + count_colors[4]
        small_classes_colors = count_pixels - back_other_colors

        if back_other_colors / count_pixels < 1 - coeff:
            print('Количество пикселей на класс: ', count_colors)

            new_images.append(image)
            new_masks.append(mask)

            down_image = np.flipud(image)
            down_mask = np.flipud(mask)
            new_images.append(down_image)
            new_masks.append(down_mask)

            right_image = np.fliplr(image)
            right_mask = np.fliplr(mask)
            new_images.append(right_image)
            new_masks.append(right_mask)

            right_down_image = np.flipud(right_image)
            right_down_mask = np.flipud(right_mask)
            new_images.append(right_down_image)
            new_masks.append(right_down_mask)

    new_images = np.array(new_images)
    new_masks = np.array(new_masks)

    new_dataset = CustomDataset(new_images, new_masks)
    return new_dataset
In [83]:
new_train_dataset = data_rebalancing(train_dataset, 0.2)
Количество пикселей на класс:  [4856, 690, 0, 764, 90]
Количество пикселей на класс:  [1290, 1321, 1371, 0, 2418]
Количество пикселей на класс:  [4886, 463, 1051, 0, 0]
Количество пикселей на класс:  [1576, 1815, 385, 0, 2624]
Количество пикселей на класс:  [1451, 146, 0, 1502, 3301]
Количество пикселей на класс:  [2011, 1674, 0, 1769, 946]
Количество пикселей на класс:  [4962, 1081, 357, 0, 0]
Количество пикселей на класс:  [5011, 403, 986, 0, 0]
Количество пикселей на класс:  [1295, 1459, 0, 2067, 1579]
Количество пикселей на класс:  [1543, 2165, 0, 0, 2692]
Количество пикселей на класс:  [4010, 0, 2390, 0, 0]
Количество пикселей на класс:  [2047, 1172, 127, 0, 3054]
Количество пикселей на класс:  [5016, 1384, 0, 0, 0]
Количество пикселей на класс:  [2446, 1694, 0, 0, 2260]
Количество пикселей на класс:  [4698, 1577, 124, 0, 1]
Количество пикселей на класс:  [1394, 1352, 0, 2578, 1076]
Количество пикселей на класс:  [2520, 1493, 0, 0, 2387]
Количество пикселей на класс:  [2620, 2186, 445, 0, 1149]
Количество пикселей на класс:  [5085, 508, 159, 648, 0]
Количество пикселей на класс:  [2467, 1828, 0, 0, 2105]
Количество пикселей на класс:  [4972, 749, 679, 0, 0]
Количество пикселей на класс:  [4856, 907, 637, 0, 0]
Количество пикселей на класс:  [4410, 1615, 375, 0, 0]
Количество пикселей на класс:  [1397, 1452, 0, 983, 2568]
Количество пикселей на класс:  [2498, 1366, 1, 0, 2535]
Количество пикселей на класс:  [2209, 1101, 337, 0, 2753]
Количество пикселей на класс:  [1091, 1296, 532, 0, 3481]
Количество пикселей на класс:  [1984, 1103, 266, 0, 3047]
Количество пикселей на класс:  [0, 1385, 0, 0, 5015]
Количество пикселей на класс:  [1497, 1460, 0, 1944, 1499]
Количество пикселей на класс:  [2774, 617, 73, 1634, 1302]
Количество пикселей на класс:  [5076, 847, 323, 154, 0]
Количество пикселей на класс:  [1917, 2149, 0, 0, 2334]
Количество пикселей на класс:  [2761, 667, 32, 1283, 1657]
Количество пикселей на класс:  [1583, 1649, 0, 1, 3167]
Количество пикселей на класс:  [656, 0, 0, 3583, 2161]
Количество пикселей на класс:  [2621, 0, 1, 1636, 2142]
Количество пикселей на класс:  [750, 0, 0, 1767, 3883]
Количество пикселей на класс:  [0, 0, 3, 1732, 4665]
Количество пикселей на класс:  [1288, 0, 5, 2769, 2338]
Количество пикселей на класс:  [1192, 1, 4, 1897, 3306]
Количество пикселей на класс:  [0, 0, 1, 4332, 2067]
Количество пикселей на класс:  [0, 0, 0, 1560, 4840]
Количество пикселей на класс:  [1590, 0, 0, 3469, 1341]
Количество пикселей на класс:  [4608, 0, 0, 1420, 372]
Количество пикселей на класс:  [270, 0, 5, 3212, 2913]
Количество пикселей на класс:  [2621, 0, 0, 3376, 403]
Количество пикселей на класс:  [2354, 0, 1, 1964, 2081]
Количество пикселей на класс:  [4690, 0, 0, 1710, 0]
Количество пикселей на класс:  [776, 0, 0, 5451, 173]
Количество пикселей на класс:  [0, 0, 1, 1614, 4785]
Количество пикселей на класс:  [4002, 0, 0, 1727, 671]
Количество пикселей на класс:  [2650, 0, 13, 2130, 1607]
Количество пикселей на класс:  [435, 0, 2, 1290, 4673]
Количество пикселей на класс:  [932, 0, 4, 1450, 4014]
Количество пикселей на класс:  [404, 0, 5, 3041, 2950]
Количество пикселей на класс:  [3609, 0, 0, 2143, 648]
Количество пикселей на класс:  [2903, 0, 0, 2942, 555]
Количество пикселей на класс:  [720, 0, 4, 4113, 1563]
Количество пикселей на класс:  [2839, 0, 0, 1732, 1829]
Количество пикселей на класс:  [405, 0, 1, 4187, 1807]
Количество пикселей на класс:  [2895, 62, 5, 1624, 1814]
Количество пикселей на класс:  [3045, 0, 6, 1433, 1916]
Количество пикселей на класс:  [3142, 0, 0, 1434, 1824]
Количество пикселей на класс:  [0, 0, 3, 1702, 4695]
Количество пикселей на класс:  [3078, 0, 0, 1886, 1436]
Количество пикселей на класс:  [0, 0, 1, 1308, 5091]
Количество пикселей на класс:  [2298, 1, 3, 1564, 2534]
Количество пикселей на класс:  [395, 0, 2, 1541, 4462]
Количество пикселей на класс:  [2802, 0, 0, 1577, 2021]
Количество пикселей на класс:  [1404, 0, 1, 1865, 3130]
Количество пикселей на класс:  [3004, 1, 0, 2329, 1066]
Количество пикселей на класс:  [4851, 0, 0, 1549, 0]
Количество пикселей на класс:  [432, 0, 1, 2364, 3603]
Количество пикселей на класс:  [1388, 0, 16, 2875, 2121]
Количество пикселей на класс:  [0, 0, 0, 1430, 4970]
Количество пикселей на класс:  [0, 0, 0, 1533, 4867]
Количество пикселей на класс:  [0, 0, 0, 2063, 4337]
Количество пикселей на класс:  [0, 0, 1, 1959, 4440]
Количество пикселей на класс:  [0, 0, 0, 1656, 4744]
Количество пикселей на класс:  [3067, 1, 5, 2488, 839]
Количество пикселей на класс:  [4580, 0, 0, 1820, 0]
Количество пикселей на класс:  [1851, 4, 0, 2502, 2043]
Количество пикселей на класс:  [0, 0, 2, 2604, 3794]
Количество пикселей на класс:  [0, 0, 0, 1658, 4742]
Количество пикселей на класс:  [1061, 0, 10, 2821, 2508]
Количество пикселей на класс:  [131, 0, 0, 2636, 3633]
Количество пикселей на класс:  [3971, 0, 0, 1490, 939]
Количество пикселей на класс:  [4968, 0, 0, 1432, 0]
Количество пикселей на класс:  [1688, 0, 0, 2404, 2308]
Количество пикселей на класс:  [2749, 0, 11, 1995, 1645]
Количество пикселей на класс:  [0, 0, 0, 2152, 4248]
Количество пикселей на класс:  [3271, 2, 0, 1727, 1400]
Количество пикселей на класс:  [2408, 0, 4, 2246, 1742]
Количество пикселей на класс:  [1658, 0, 0, 1318, 3424]
Количество пикселей на класс:  [0, 0, 1, 1928, 4471]
Количество пикселей на класс:  [1048, 0, 18, 1274, 4060]
Количество пикселей на класс:  [2474, 0, 0, 2235, 1691]
Количество пикселей на класс:  [2453, 0, 0, 1815, 2132]
Количество пикселей на класс:  [4869, 0, 0, 1531, 0]
Количество пикселей на класс:  [0, 0, 0, 1476, 4924]
Количество пикселей на класс:  [4602, 0, 0, 1798, 0]
Количество пикселей на класс:  [668, 0, 0, 1865, 3867]
Количество пикселей на класс:  [0, 0, 0, 1289, 5111]
Количество пикселей на класс:  [4025, 0, 0, 1766, 609]
Количество пикселей на класс:  [0, 0, 3, 1554, 4843]
Количество пикселей на класс:  [650, 0, 21, 3817, 1912]
Количество пикселей на класс:  [0, 0, 0, 1406, 4994]
Количество пикселей на класс:  [1035, 1, 3, 3154, 2207]
Количество пикселей на класс:  [1013, 0, 4, 1285, 4098]
Количество пикселей на класс:  [3577, 0, 0, 2141, 682]
Количество пикселей на класс:  [0, 0, 0, 1420, 4980]
Количество пикселей на класс:  [896, 0, 0, 3546, 1958]
Количество пикселей на класс:  [2049, 0, 5, 1382, 2964]
Количество пикселей на класс:  [0, 0, 0, 2029, 4371]
Количество пикселей на класс:  [0, 0, 0, 1804, 4596]
Количество пикселей на класс:  [2822, 0, 6, 1897, 1675]
Количество пикселей на класс:  [0, 0, 0, 2898, 3502]
Количество пикселей на класс:  [1369, 0, 2, 1448, 3581]
Количество пикселей на класс:  [2147, 0, 3, 3609, 641]
Количество пикселей на класс:  [0, 0, 3, 1622, 4775]
Количество пикселей на класс:  [80, 0, 0, 1659, 4661]
Количество пикселей на класс:  [0, 0, 2, 1795, 4603]
Количество пикселей на класс:  [0, 0, 0, 1298, 5102]
Количество пикселей на класс:  [0, 0, 2, 2104, 4294]
Количество пикселей на класс:  [1748, 0, 5, 1926, 2721]
Количество пикселей на класс:  [0, 0, 1, 2683, 3716]
Количество пикселей на класс:  [5041, 0, 0, 1359, 0]
Количество пикселей на класс:  [830, 0, 3, 1727, 3840]
Количество пикселей на класс:  [0, 0, 1, 2048, 4351]
Количество пикселей на класс:  [2443, 1, 4, 1535, 2417]
Количество пикселей на класс:  [4783, 0, 0, 1617, 0]
Количество пикселей на класс:  [497, 1, 3, 2782, 3117]
Количество пикселей на класс:  [3262, 1, 4, 1459, 1674]
Количество пикселей на класс:  [1336, 0, 6, 2786, 2272]
Количество пикселей на класс:  [0, 0, 1, 1495, 4904]
Количество пикселей на класс:  [0, 0, 1, 1782, 4617]
Количество пикселей на класс:  [2204, 0, 2, 2222, 1972]
Количество пикселей на класс:  [12, 0, 4, 2218, 4166]
Количество пикселей на класс:  [4986, 0, 0, 1414, 0]
Количество пикселей на класс:  [3355, 0, 10, 1350, 1685]
Количество пикселей на класс:  [0, 0, 1, 1357, 5042]
Количество пикселей на класс:  [2354, 0, 0, 1503, 2543]
Количество пикселей на класс:  [992, 0, 0, 1325, 4083]
Количество пикселей на класс:  [4932, 0, 0, 1468, 0]
Количество пикселей на класс:  [52, 0, 4, 2861, 3483]
Количество пикселей на класс:  [2489, 0, 13, 1838, 2060]
Количество пикселей на класс:  [0, 1585, 258, 0, 4557]
Количество пикселей на класс:  [1404, 0, 5, 3767, 1224]
Количество пикселей на класс:  [2751, 0, 0, 3649, 0]
Количество пикселей на класс:  [1772, 1, 1, 2791, 1835]
Количество пикселей на класс:  [0, 0, 2, 3043, 3355]
Количество пикселей на класс:  [1243, 0, 5, 2318, 2834]
Количество пикселей на класс:  [1425, 0, 2, 3150, 1823]
Количество пикселей на класс:  [0, 0, 0, 2203, 4197]
Количество пикселей на класс:  [3591, 0, 0, 2106, 703]
Количество пикселей на класс:  [0, 0, 3, 1530, 4867]
Количество пикселей на класс:  [0, 0, 8, 3853, 2539]
Количество пикселей на класс:  [2330, 60, 3, 2093, 1914]
Количество пикселей на класс:  [0, 0, 0, 1362, 5038]
Количество пикселей на класс:  [4402, 0, 0, 1998, 0]
Количество пикселей на класс:  [0, 0, 3, 1679, 4718]
Количество пикселей на класс:  [2178, 80, 6, 2605, 1531]
Количество пикселей на класс:  [87, 1, 77, 1765, 4470]
Количество пикселей на класс:  [29, 0, 38, 1668, 4665]
Количество пикселей на класс:  [1283, 0, 0, 5117, 0]
Количество пикселей на класс:  [62, 0, 0, 2176, 4162]
Количество пикселей на класс:  [0, 0, 0, 1661, 4739]
Количество пикселей на класс:  [1226, 0, 1, 3125, 2048]
Количество пикселей на класс:  [1830, 0, 0, 2147, 2423]
Количество пикселей на класс:  [2114, 0, 6, 1758, 2522]
Количество пикселей на класс:  [4456, 0, 0, 1944, 0]
Количество пикселей на класс:  [1497, 0, 23, 2296, 2584]
Количество пикселей на класс:  [709, 2909, 3, 637, 2142]
Количество пикселей на класс:  [1066, 0, 50, 1597, 3687]
Количество пикселей на класс:  [541, 1, 50, 1425, 4383]
Количество пикселей на класс:  [622, 1, 20, 1721, 4036]
Количество пикселей на класс:  [0, 0, 2, 1876, 4522]
Количество пикселей на класс:  [7, 0, 22, 1516, 4855]
Количество пикселей на класс:  [2865, 0, 1, 1601, 1933]
Количество пикселей на класс:  [2515, 52, 0, 2485, 1348]
Количество пикселей на класс:  [1549, 0, 8, 1630, 3213]
Количество пикселей на класс:  [0, 0, 1, 1308, 5091]
Количество пикселей на класс:  [4371, 0, 3, 1285, 741]
Количество пикселей на класс:  [0, 0, 0, 1725, 4675]
Количество пикселей на класс:  [4077, 0, 0, 2323, 0]
Количество пикселей на класс:  [1561, 297, 3, 2863, 1676]
Количество пикселей на класс:  [1703, 0, 11, 1957, 2729]
Количество пикселей на класс:  [0, 0, 7, 2341, 4052]
Количество пикселей на класс:  [2680, 0, 3, 2061, 1656]
Количество пикселей на класс:  [3275, 0, 5, 2583, 537]
Количество пикселей на класс:  [2117, 0, 5, 2638, 1640]
Количество пикселей на класс:  [0, 0, 7, 1390, 5003]
Количество пикселей на класс:  [0, 0, 2, 1965, 4433]
Количество пикселей на класс:  [4764, 0, 0, 1636, 0]
Количество пикселей на класс:  [383, 0, 0, 1328, 4689]
Количество пикселей на класс:  [0, 0, 3, 3670, 2727]
Количество пикселей на класс:  [4026, 0, 3, 1608, 763]
Количество пикселей на класс:  [4137, 0, 0, 2263, 0]
Количество пикселей на класс:  [0, 0, 4, 1412, 4984]
Количество пикселей на класс:  [3347, 0, 0, 1295, 1758]
Количество пикселей на класс:  [4860, 0, 0, 1540, 0]
Количество пикселей на класс:  [4021, 0, 0, 2379, 0]
Количество пикселей на класс:  [0, 0, 0, 2423, 3977]
Количество пикселей на класс:  [1616, 0, 25, 1256, 3503]
Количество пикселей на класс:  [743, 0, 0, 2654, 3003]
Количество пикселей на класс:  [1585, 0, 17, 1292, 3506]
Количество пикселей на класс:  [3975, 0, 0, 2425, 0]
Количество пикселей на класс:  [2152, 1, 8, 1613, 2626]
Количество пикселей на класс:  [1362, 0, 5, 4385, 648]
Количество пикселей на класс:  [842, 2, 25, 2003, 3528]
Количество пикселей на класс:  [4855, 0, 0, 1545, 0]
Количество пикселей на класс:  [870, 3, 8, 3975, 1544]
Количество пикселей на класс:  [1564, 0, 1, 3415, 1420]
Количество пикселей на класс:  [4338, 0, 1, 1541, 520]
Количество пикселей на класс:  [0, 0, 0, 1830, 4570]
Количество пикселей на класс:  [4857, 0, 0, 1543, 0]
Количество пикселей на класс:  [4533, 1553, 314, 0, 0]
Количество пикселей на класс:  [3560, 2037, 803, 0, 0]
Количество пикселей на класс:  [0, 0, 9, 2423, 3968]
Количество пикселей на класс:  [1598, 1661, 0, 0, 3141]
Количество пикселей на класс:  [0, 0, 0, 3006, 3394]
Количество пикселей на класс:  [0, 0, 0, 1636, 4764]
Количество пикселей на класс:  [2104, 1424, 0, 0, 2872]
Количество пикселей на класс:  [4429, 0, 0, 1971, 0]
Количество пикселей на класс:  [5035, 1365, 0, 0, 0]
Количество пикселей на класс:  [270, 0, 0, 1313, 4817]
Количество пикселей на класс:  [0, 0, 0, 1942, 4458]
Количество пикселей на класс:  [4463, 1937, 0, 0, 0]
Количество пикселей на класс:  [4770, 1630, 0, 0, 0]
Количество пикселей на класс:  [3976, 0, 0, 2424, 0]
Количество пикселей на класс:  [0, 0, 0, 2877, 3523]
Количество пикселей на класс:  [0, 0, 0, 2172, 4228]
Количество пикселей на класс:  [2155, 72, 1, 2658, 1514]
Количество пикселей на класс:  [0, 0, 1, 2262, 4137]
Количество пикселей на класс:  [1089, 0, 0, 1342, 3969]
Количество пикселей на класс:  [1623, 30, 0, 2738, 2009]
Количество пикселей на класс:  [1593, 1485, 0, 0, 3322]
Количество пикселей на класс:  [2103, 1711, 0, 0, 2586]
In [84]:
print('New train dataset:')
dataset_info(new_train_dataset)
New train dataset:
Размер датасета изображений: torch.Size([956, 80, 80, 3])
Размер датасета масок: torch.Size([956, 80, 80, 5])

Класс: Background, Число пикселей: 1762760(28.81%)
Класс: Human divers, Число пикселей: 250992(4.10%)
Класс: Robots, Число пикселей: 51116(0.84%)
Класс: Fish and vertebrates, Число пикселей: 1698740(27.76%)
Класс: Other, Число пикселей: 2354792(38.49%)

Настройка гиперпараметров

In [85]:
learning_rate = 0.01
epochs = 50
batch_size = 36
In [86]:
new_training_dataset, new_validation_dataset = torch.utils.data.random_split(new_train_dataset, [0.85, 0.15])
new_training_loader = DataLoader(new_training_dataset, batch_size = batch_size, shuffle = True)
new_validation_loader = DataLoader(new_validation_dataset, batch_size = batch_size, shuffle = True)
In [87]:
new_net = UNet().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(new_net.parameters(), lr = learning_rate)
In [88]:
train(new_net, new_training_loader, new_validation_loader, criterion, epochs)
Epoch: 1/50, Train Loss: 0.4051, Time: 3.98 Validation Loss: 0.3252, Time: 0.31
Epoch: 2/50, Train Loss: 0.3262, Time: 3.91 Validation Loss: 0.3164, Time: 0.32
Epoch: 3/50, Train Loss: 0.3095, Time: 3.92 Validation Loss: 0.3135, Time: 0.32
Epoch: 4/50, Train Loss: 0.2874, Time: 3.93 Validation Loss: 0.2976, Time: 0.32
Epoch: 5/50, Train Loss: 0.2923, Time: 3.94 Validation Loss: 0.2854, Time: 0.32
Epoch: 6/50, Train Loss: 0.2801, Time: 3.95 Validation Loss: 0.2828, Time: 0.31
Epoch: 7/50, Train Loss: 0.2661, Time: 3.96 Validation Loss: 0.2777, Time: 0.32
Epoch: 8/50, Train Loss: 0.2621, Time: 3.96 Validation Loss: 0.2631, Time: 0.32
Epoch: 9/50, Train Loss: 0.2582, Time: 3.98 Validation Loss: 0.2560, Time: 0.32
Epoch: 10/50, Train Loss: 0.2515, Time: 3.99 Validation Loss: 0.2709, Time: 0.32
Epoch: 11/50, Train Loss: 0.2450, Time: 3.99 Validation Loss: 0.2595, Time: 0.32
Epoch: 12/50, Train Loss: 0.2384, Time: 4.01 Validation Loss: 0.2493, Time: 0.33
Epoch: 13/50, Train Loss: 0.2275, Time: 4.02 Validation Loss: 0.2274, Time: 0.32
Epoch: 14/50, Train Loss: 0.2209, Time: 4.02 Validation Loss: 0.2287, Time: 0.32
Epoch: 15/50, Train Loss: 0.2092, Time: 4.02 Validation Loss: 0.2352, Time: 0.32
Epoch: 16/50, Train Loss: 0.2240, Time: 4.02 Validation Loss: 0.2297, Time: 0.32
Epoch: 17/50, Train Loss: 0.2151, Time: 4.03 Validation Loss: 0.2269, Time: 0.33
Epoch: 18/50, Train Loss: 0.2068, Time: 4.05 Validation Loss: 0.2320, Time: 0.32
Epoch: 19/50, Train Loss: 0.2054, Time: 4.06 Validation Loss: 0.2132, Time: 0.33
Epoch: 20/50, Train Loss: 0.2023, Time: 4.06 Validation Loss: 0.2086, Time: 0.33
Epoch: 21/50, Train Loss: 0.1865, Time: 4.07 Validation Loss: 0.2156, Time: 0.33
Epoch: 22/50, Train Loss: 0.1779, Time: 4.08 Validation Loss: 0.1953, Time: 0.33
Epoch: 23/50, Train Loss: 0.1712, Time: 4.08 Validation Loss: 0.2005, Time: 0.33
Epoch: 24/50, Train Loss: 0.1640, Time: 4.08 Validation Loss: 0.1952, Time: 0.33
Epoch: 25/50, Train Loss: 0.1796, Time: 4.09 Validation Loss: 0.1955, Time: 0.33
Epoch: 26/50, Train Loss: 0.1732, Time: 4.10 Validation Loss: 0.1884, Time: 0.33
Epoch: 27/50, Train Loss: 0.1541, Time: 4.11 Validation Loss: 0.1936, Time: 0.33
Epoch: 28/50, Train Loss: 0.1448, Time: 4.12 Validation Loss: 0.1927, Time: 0.33
Epoch: 29/50, Train Loss: 0.1436, Time: 4.12 Validation Loss: 0.1847, Time: 0.33
Epoch: 30/50, Train Loss: 0.1468, Time: 4.12 Validation Loss: 0.2135, Time: 0.33
Epoch: 31/50, Train Loss: 0.1509, Time: 4.13 Validation Loss: 0.2036, Time: 0.33
Epoch: 32/50, Train Loss: 0.1393, Time: 4.12 Validation Loss: 0.1720, Time: 0.34
Epoch: 33/50, Train Loss: 0.1299, Time: 4.13 Validation Loss: 0.1649, Time: 0.33
Epoch: 34/50, Train Loss: 0.1243, Time: 4.14 Validation Loss: 0.1686, Time: 0.33
Epoch: 35/50, Train Loss: 0.1244, Time: 4.14 Validation Loss: 0.1680, Time: 0.33
Epoch: 36/50, Train Loss: 0.1179, Time: 4.15 Validation Loss: 0.1728, Time: 0.34
Epoch: 37/50, Train Loss: 0.1123, Time: 4.15 Validation Loss: 0.1638, Time: 0.34
Epoch: 38/50, Train Loss: 0.1132, Time: 4.14 Validation Loss: 0.1599, Time: 0.33
Epoch: 39/50, Train Loss: 0.1108, Time: 4.15 Validation Loss: 0.1600, Time: 0.33
Epoch: 40/50, Train Loss: 0.1058, Time: 4.16 Validation Loss: 0.1451, Time: 0.34
Epoch: 41/50, Train Loss: 0.1041, Time: 4.16 Validation Loss: 0.1619, Time: 0.34
Epoch: 42/50, Train Loss: 0.1025, Time: 4.16 Validation Loss: 0.1629, Time: 0.33
Epoch: 43/50, Train Loss: 0.1012, Time: 4.17 Validation Loss: 0.1505, Time: 0.34
Epoch: 44/50, Train Loss: 0.1080, Time: 4.16 Validation Loss: 0.1594, Time: 0.34
Epoch: 45/50, Train Loss: 0.1173, Time: 4.16 Validation Loss: 0.1796, Time: 0.34
Epoch: 46/50, Train Loss: 0.1135, Time: 4.17 Validation Loss: 0.1653, Time: 0.33
Epoch: 47/50, Train Loss: 0.0982, Time: 4.17 Validation Loss: 0.1450, Time: 0.34
Epoch: 48/50, Train Loss: 0.0888, Time: 4.16 Validation Loss: 0.1393, Time: 0.34
Epoch: 49/50, Train Loss: 0.0876, Time: 4.17 Validation Loss: 0.1342, Time: 0.34
Epoch: 50/50, Train Loss: 0.0832, Time: 4.17 Validation Loss: 0.1302, Time: 0.34
In [95]:
metrics_compute(new_net, test_loader)
Оценка Accuracy для каждого класса: [0.8496 0.9673 0.9946 0.7993 0.7217]
Оценка IoU для каждого класса: [0.6587  0.1621  0.11926 0.1731  0.5244 ]
Оценка Accuracy на данных: 0.8662
Оценка IoU на данных: 0.3276
No description has been provided for this image
In [98]:
plot_images_with_masks_test(new_net, test_dataset)
No description has been provided for this image